Conv2dTranspose

Computes a 2D transposed convolution, which can be viewed as the gradient of Conv2d with respect to its input. It is also referred to as deconvolution (although it is not an actual deconvolution).

The input shape is typically \((N, H_{in}, W_{in}, C_{in})\), where:

  • \(N\) is the batch size

  • \(C_{in}\) is the number of input channels

  • \(H_{in}, W_{in}\) are the height and width of the feature map, respectively
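
For reference, assuming the conventional transposed-convolution output formula with symmetric padding (an assumption; the exact padding semantics of this library are not specified here), the output spatial size is

\(H_{out} = (H_{in} - 1) \times stride_h - 2 \times pad_h + dilation_h \times (kernel_h - 1) + 1\)

and analogously for \(W_{out}\). This matches the examples below: an input of \(5 \times 7\) with a \(3 \times 3\) kernel, stride 1, padding 0, and dilation 1 yields an output of \(7 \times 9\).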

Inputs:
  • input_x - address of the input data

  • input_w - address of the convolution kernel weights

  • bias - address of the bias

  • param - structure holding the parameters required by the operator; its members are described below.

  • core_mask - core mask.

ConvTransposeParameter definition:

typedef struct ConvTransposeParameter {
    void* workspace_; // buffer for intermediate results
    int output_batch_; // total number of output batches
    int input_batch_; // total number of input batches
    int input_h_; // input height (h dimension)
    int input_w_; // input width (w dimension)
    int output_h_; // output height (h dimension)
    int output_w_; // output width (w dimension)
    int input_channel_; // number of input channels
    int output_channel_; // number of output channels
    int kernel_h_; // kernel height (h dimension)
    int kernel_w_; // kernel width (w dimension)
    int group_; // number of groups
    int pad_l_; // left padding
    int pad_u_; // top padding
    int dilation_h_; // kernel dilation along h
    int dilation_w_; // kernel dilation along w
    int stride_h_; // stride along h
    int stride_w_; // stride along w
    int buffer_size_; // buffer size allocated for tiled (blocked) computation
} ConvTransposeParameter;
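
If the output shape is not known in advance, output_h_ and output_w_ can be derived from the other members. A minimal sketch, assuming the conventional transposed-convolution output formula given above (ConvTransposeOutDim is an illustrative helper, not part of the library API):

static int ConvTransposeOutDim(int in_dim, int kernel, int stride, int pad_total, int dilation) {
    // conventional transposed-convolution output size:
    // (in - 1) * stride - total_padding + dilation * (kernel - 1) + 1
    return (in_dim - 1) * stride - pad_total + dilation * (kernel - 1) + 1;
}

For the shapes used in the examples below, ConvTransposeOutDim(5, 3, 1, 0, 1) = 7 and ConvTransposeOutDim(7, 3, 1, 0, 1) = 9, i.e. output_h_ = 7 and output_w_ = 9.
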
Output:
  • out_y - output address.

Supported platforms:

FT78NE MT7004

Notes

  • FT78NE supports int8, fp32

  • MT7004 supports fp16, fp32

Shared-memory version:

void i8_convtranspose_s(int8_t *input_x, int8_t *input_w, int8_t *out_y, int *bias, ConvTransposeParameter *conv_param, int core_mask)
void hp_convtranspose_s(half *input_x, half *input_w, half *out_y, half *bias, ConvTransposeParameter *conv_param, int core_mask)
void fp_convtranspose_s(float *input_x, float *input_w, float *out_y, float *bias, ConvTransposeParameter *conv_param, int core_mask)

C usage example:

#include <string.h> // for memcpy/memset; the platform SDK header declaring the operator and core APIs is assumed elsewhere

void TestConvTransposeSMCFp32(int* input_shape, int* weight_shape, int* output_shape, int* stride, int* padding, int* dilation, int groups, float* bias, int core_mask) {
    int core_id = get_core_id();
    int logic_core_id = GetLogicCoreId(core_mask, core_id);
    int core_num = GetCoreNum(core_mask);
    float* input_data = (float*)0x88000000;
    float* weight = (float*)0x89000000;
    float* output_data = (float*)0x90000000;
    float* bias_data = (float*)0x91000000;
    float* check = (float*)0x94000000;
    ConvTransposeParameter* param = (ConvTransposeParameter*)0x92000000;
    if (logic_core_id == 0) {
        memcpy(bias_data, bias, sizeof(float) * output_shape[3]);
        memset(output_data, 0, output_shape[0] * output_shape[1] * output_shape[2] * output_shape[3] * sizeof(float));
        memset(check, 0, output_shape[0] * output_shape[1] * output_shape[2] * output_shape[3] * sizeof(float));
        param->dilation_h_ = dilation[0];
        param->dilation_w_ = dilation[1];
        param->group_ = groups;
        param->input_batch_ = input_shape[0];
        param->input_h_ = input_shape[1];
        param->input_w_ = input_shape[2];
        param->input_channel_ = input_shape[3];
        param->kernel_h_ = weight_shape[1];
        param->kernel_w_ = weight_shape[2];
        param->output_batch_ = output_shape[0];
        param->output_h_ = output_shape[1];
        param->output_w_ = output_shape[2];
        param->output_channel_ = output_shape[3];
        param->stride_h_ = stride[0];
        param->stride_w_ = stride[1];
        param->pad_u_ = padding[0];
        param->pad_l_ = padding[2];
        param->workspace_ = (float*)0xA0000000;
    }
    sys_bar(0, core_num); // synchronize once parameter initialization is complete
    fp_convtranspose_s(input_data, weight, output_data, bias_data, param, core_mask);
}

void main(){
    int in_channel = 6;
    int out_channel = 6;
    int groups = 6;
    int input_shape[4] = {2, 5, 7, in_channel}; // NHWC
    int weight_shape[4] = {in_channel, 3, 3, out_channel / groups};
    int output_shape[4] = {2, 7, 9, out_channel}; // NHWC
    int stride[2] = {1, 1};
    int padding[4] = {0, 0, 0, 0};
    int dilation[2] = {1, 1};
    float bias[] = {0, 0, 0, 0, 0, 0};
    int core_mask = 0b1111;
    TestConvTransposeSMCFp32(input_shape, weight_shape, output_shape, stride, padding, dilation, groups, bias, core_mask);
}
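
The check buffer in the example above is reserved for a host-side reference result to compare against output_data. A minimal reference sketch that could fill it, assuming standard transposed-convolution semantics, NHWC layout, and weights of shape [input_channel, kernel_h, kernel_w, output_channel / groups] as in these examples (RefConvTransposeFp32 is illustrative, not part of the library API):

void RefConvTransposeFp32(const float* x, const float* w, const float* bias,
                          float* y, const ConvTransposeParameter* p) {
    int oc_per_group = p->output_channel_ / p->group_;
    int ic_per_group = p->input_channel_ / p->group_;
    // Initialize every output element with its channel bias.
    for (int n = 0; n < p->output_batch_; ++n)
        for (int oh = 0; oh < p->output_h_; ++oh)
            for (int ow = 0; ow < p->output_w_; ++ow)
                for (int oc = 0; oc < p->output_channel_; ++oc)
                    y[((n * p->output_h_ + oh) * p->output_w_ + ow) * p->output_channel_ + oc] = bias[oc];
    // Scatter each input element into every output position it contributes to.
    for (int n = 0; n < p->input_batch_; ++n)
        for (int ih = 0; ih < p->input_h_; ++ih)
            for (int iw = 0; iw < p->input_w_; ++iw)
                for (int g = 0; g < p->group_; ++g)
                    for (int ic = 0; ic < ic_per_group; ++ic) {
                        int in_c = g * ic_per_group + ic;
                        float v = x[((n * p->input_h_ + ih) * p->input_w_ + iw) * p->input_channel_ + in_c];
                        for (int kh = 0; kh < p->kernel_h_; ++kh)
                            for (int kw = 0; kw < p->kernel_w_; ++kw) {
                                int oh = ih * p->stride_h_ - p->pad_u_ + kh * p->dilation_h_;
                                int ow = iw * p->stride_w_ - p->pad_l_ + kw * p->dilation_w_;
                                if (oh < 0 || oh >= p->output_h_ || ow < 0 || ow >= p->output_w_) continue;
                                for (int oc = 0; oc < oc_per_group; ++oc) {
                                    float wt = w[((in_c * p->kernel_h_ + kh) * p->kernel_w_ + kw) * oc_per_group + oc];
                                    y[((n * p->output_h_ + oh) * p->output_w_ + ow) * p->output_channel_ + g * oc_per_group + oc] += v * wt;
                                }
                            }
                    }
}

Calling RefConvTransposeFp32(input_data, weight, bias_data, check, param) after the barrier and comparing check with output_data element-wise is one way to validate the result.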

Private-memory version:

void i8_convtranspose_p(int8_t *input_x, int8_t *input_w, int8_t *out_y, int *bias, ConvTransposeParameter *conv_param, int core_mask)
void hp_convtranspose_p(half *input_x, half *input_w, half *out_y, half *bias, ConvTransposeParameter *conv_param, int core_mask)
void fp_convtranspose_p(float *input_x, float *input_w, float *out_y, float *bias, ConvTransposeParameter *conv_param, int core_mask)

C usage example:

#include <string.h> // for memcpy/memset

void TestConvTransposeL2Fp32(int* input_shape, int* weight_shape, int* output_shape, int* stride, int* padding, int* dilation, int groups, float* bias, int core_mask) {
    float* input_data = (float*)0x10000000; // for the private-memory version, addresses are placed in AM
    float* weight = (float*)0x10001000;
    float* output_data = (float*)0x10002000;
    float* bias_data = (float*)0x10003000;
    float* check = (float*)0x10004000;
    ConvTransposeParameter* param = (ConvTransposeParameter*)0x10005000;
    memcpy(bias_data, bias, sizeof(float) * output_shape[3]);
    memset(output_data, 0, output_shape[0] * output_shape[1] * output_shape[2] * output_shape[3] * sizeof(float));
    memset(check, 0, output_shape[0] * output_shape[1] * output_shape[2] * output_shape[3] * sizeof(float));
    param->dilation_h_ = dilation[0];
    param->dilation_w_ = dilation[1];
    param->group_ = groups;
    param->input_batch_ = input_shape[0];
    param->input_h_ = input_shape[1];
    param->input_w_ = input_shape[2];
    param->input_channel_ = input_shape[3];
    param->kernel_h_ = weight_shape[1];
    param->kernel_w_ = weight_shape[2];
    param->output_batch_ = output_shape[0];
    param->output_h_ = output_shape[1];
    param->output_w_ = output_shape[2];
    param->output_channel_ = output_shape[3];
    param->stride_h_ = stride[0];
    param->stride_w_ = stride[1];
    param->pad_u_ = padding[0];
    param->pad_l_ = padding[2];
    param->workspace_ = (float*)0x10006000;
    param->buffer_size_ = 1024; // must be set in the private-memory version; it determines the tile size for blocked computation
    fp_convtranspose_p(input_data, weight, output_data, bias_data, param, core_mask);
}

void main(){
    int in_channel = 6;
    int out_channel = 6;
    int groups = 6;
    int input_shape[4] = {2, 5, 7, in_channel}; // NHWC
    int weight_shape[4] = {in_channel, 3, 3, out_channel / groups};
    int output_shape[4] = {2, 7, 9, out_channel}; // NHWC
    int stride[2] = {1, 1};
    int padding[4] = {0, 0, 0, 0};
    int dilation[2] = {1, 1};
    float bias[] = {0, 0, 0, 0, 0, 0};
    int core_mask = 0b0001; // the private-memory version must be launched on a single core
    TestConvTransposeL2Fp32(input_shape, weight_shape, output_shape, stride, padding, dilation, groups, bias, core_mask);
}